% NOTE: THIS NEEDS UPDATING TO USE the following (Statistics Toolbox):
%[p,t,st] = kruskalwallis(MPG,Origin,'off');
%[c,m,h,nms] = multcompare(st,'display','on');
% These handle significance testing across more than 2 categories at
% once.  Stephen suggests also using classify to see whether the category
% can be predicted from the population response.
%
% Analyze the data from the spiking network for some task consisting of
% trials of equal duration.
%
%
%  function analyzePatternsInverted( area_name, type_name, phase_start_msec, phase_end_msec, start_time_msecs, end_time_msecs, full_analysis, significance_level, save_seconds)
% 
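%  phase_start_msec, phase_end_msec:   the start and end times within each
%		trial over which spikes are counted; start counting at 0 for the first
%		msec of the trial.  NOTE: the code counts relative times
%		phase_start_msec .. phase_end_msec-1, i.e. the end time is exclusive.
%  start_time_msecs, end_time_msecs: trials whose start time t satisfies
%		start_time_msecs <= t < end_time_msecs are analyzed.
%  full_analysis: nonzero to also run the rank-sum analysis (default 0).
%  significance_level: alpha passed to ranksum (default 0.05).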
%
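%  save_seconds = the number of seconds saved in 1 spikes*.dat file; must
%   be an integer (default 1).
%  COMPUTES (only when full_analysis is nonzero):
%	rs(neuron, pattern i, pattern j): true when ranksum finds the neuron's
%	spike counts during the task phase significantly different between
%	patterns i and j AND its mean count for pattern i is larger.
%	sig_cells(i).list: numbers (counting from 0) of the neurons for which
%	rs(neuron, i, j) holds for every other pattern j.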
%
%  EXAMPLE:
%
% analyzePatternsInverted('S1','p4',0,249,10000,29999)
%
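%  will analyze patterns
%  whose trials start between 10000 msec (inclusive) and 29999 msec
%  (exclusive) into the simulation. During each trial, the spikes will be
%  analyzed only between relative times 0 and 249 (because phase_end_msec
%  is exclusive in the code, this counts relative msec 0 through 248).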
%  The file patternhist.dat looks like this, where each line is a trial:
%  the first entry is the start time of the trial in msec and the second
%  is the pattern number, in the range [0, N-1], where N is the number
%  of patterns used in the experiment (no gaps).
% 
% 
% 0 0
% 250 1
% 500 2
% 750 3
% 1000 0
% 1250 1
% 1500 2
% 1750 3
% 2000 0
% 2250 1
% 2500 2
% 2750 3
% ...
% 
%
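% If area_name == 'mfr', no spike data is read: the global mfr structure and
% the globals rows and cols must already be set.  type_name, the phase
% times, the start/end times and save_seconds are then ignored;
% full_analysis and significance_level still apply.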
% 
% mfr is defined as follows: mfr(pattern).spikes(neuron,trial) =
%   the spikes per second that 'neuron' fires in response to 'pattern' on
%   'trial'.  Each pattern can have a different number of pattern
%   presentations.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% 05/25/2011: RGM
% To expedite our testing, we omitted any analyses that were not 
% necessary for our current work.
% In addition, we inverted all patterned images so that they match our
% neuronal display.

function [rs, sig_cells] = analyzePatternsInverted( area_name, type_name, phase_start_msec, phase_end_msec , start_time_msecs, end_time_msecs, full_analysis, significance_level, save_seconds)
% ANALYZEPATTERNSINVERTED  Per-pattern firing-rate maps, a linear read-out
% classifier and (optionally) rank-sum discrimination maps for one area.
%
% Inputs:
%   area_name, type_name - area and cell type looked up in groups.dat (via
%       read_groups), or area_name = 'mfr' to use the precomputed global mfr.
%   phase_start_msec, phase_end_msec - window inside each trial; spikes at
%       relative times phase_start_msec .. phase_end_msec-1 are counted.
%   start_time_msecs, end_time_msecs - trials whose start time t (column 1
%       of patternhist.dat) satisfies start_time_msecs <= t < end_time_msecs
%       are analyzed.
%   full_analysis      - nonzero to run the rank-sum analysis (default 0).
%   significance_level - alpha passed to ranksum (default 0.05).
%   save_seconds       - seconds of data per spikes*.dat file (default 1).
%
% Outputs (optional; both [] unless full_analysis is nonzero, so calls
% without output arguments behave as before):
%   rs(n,i,j) - true when ranksum finds neuron n's responses to patterns i
%       and j significantly different AND its mean response to i is larger.
%   sig_cells(i).list - neuron numbers (counting from 0, offset by the
%       area's first id from groups.dat; relative to the map in 'mfr' mode)
%       with rs(n,i,j) true for every j ~= i.
%
% Side effects: sets the globals mfr, rows and cols (except in 'mfr' mode),
% prints progress/debug values, and draws figures 100, 101, ...

%global samples cats
global mfr rows cols
msecs_per_sec = 1000;

% Defaults so early returns and the short path still assign the outputs.
rs = [];
sig_cells = [];

if nargin < 7
    full_analysis = 0;
end
if nargin < 8
    significance_level = 0.05;
end
if nargin < 9
    save_seconds = 1;
end

if strcmp(area_name,'mfr')
    % global variable mfr is defined.  Do not read spike data.
    area_name='network';
    type_name='neuron';
    phase_start_msec = 0;
    phase_end_msec = 1000; % makes the rate scale 1000/(end-start) below equal 1

    N=rows*cols

    first = 1; % first neuron number is not relevant here.

    num_patterns = length(mfr);
    count = zeros(1, num_patterns);
    for i = 1:num_patterns
        count(i) = size( mfr(i).spikes,2);
    end

else
    patternhist = load('patternhist.dat');

    % Shift to 1-based columns of S: column t+phase_start_msec holds the
    % spikes at absolute time t + (requested start).
    phase_start_msec = phase_start_msec + 1;
    phase_end_msec = phase_end_msec + 1;

    read_groups; % Assumes groups.dat has the neural area data

    index = [];
    for i=1:length(areas)
        if strcmp( deblank(areas{i}), area_name) && strcmp( type_name, deblank(celltype{i}))
            index = i
            break;
        end
    end
    if isempty(index)
        disp([area_name ' ' type_name ' is not in groups.dat.']);
        return;
    end

    first = neuron_id(index,1)
    last  = neuron_id(index,2)

    num_neurons = max(neuron_id(:,2))+1

    % 1-based row indices into S.
    first = first + 1;
    last = last + 1;
    rows = la(index);
    cols = lb(index);
    N=rows*cols

    num_patterns = max( patternhist(:,2))+1

    count = zeros( 1, num_patterns );

    % One (neuron x trial) spike-count matrix per pattern.  Preallocating
    % every pattern keeps patterns without trials in the time range as
    % empty (neurons x 0) entries instead of leaving mfr too short.
    mfr = struct('spikes', repmat({zeros(last-first+1, 0)}, 1, num_patterns));

    msecs_per_file = save_seconds*msecs_per_sec;
    trial_indices_to_analyze = find(patternhist(:,1)>=start_time_msecs & patternhist(:,1)<end_time_msecs);

    for i = trial_indices_to_analyze'

        t = patternhist( i, 1 );

        % Spike files needed for this trial's window.  The naming
        % (spikes<msec>.dat, msec = file index * msecs_per_file) is inherited
        % from the original code and implies file X holds times
        % X-msecs_per_file .. X-1 -- TODO confirm with the simulator.
        start_file = ceil( ( t+phase_start_msec)/msecs_per_file);
        stop_file = ceil( ( t+phase_end_msec-1)/msecs_per_file);

        sdat = [];
        for file = start_file:stop_file
            filename = ['spikes' num2str( file*msecs_per_file, 15 ) '.dat']
            sdat = [sdat; load(filename)]; %#ok<AGROW>
        end

        % Columns needed to cover every msec of the loaded files (this
        % previously ignored save_seconds and could undersize S).
        max_time_msec = stop_file*msecs_per_file
        S = spike_matrix(sdat, num_neurons, max_time_msec);

        current_pattern = patternhist( i , 2 )+1;
        count(current_pattern) = count(current_pattern) + 1;
        c=count(current_pattern);

        % Gather mean rates for stats.
        first
        last
        (t+phase_start_msec)
        (t+phase_end_msec)
        disp('****************************')
        mfr(current_pattern).spikes(:, c ) = full(sum( S( first:last, (t+phase_start_msec):(t+phase_end_msec-1)), 2));
        current_pattern
    end
end

count

k=0;
figure(100+k); k= k + 1;

disp('Mean firing rate of cells to each pattern');
% rot90 turns each rows x cols map into a cols x rows one, so size the
% array to match (a rows x cols allocation only worked for square areas).
meanresponses=zeros(cols,rows,num_patterns);
side = ceil(sqrt(num_patterns));
for i = 1:num_patterns
    subplot(side,side,i);

    % Mean spike count per trial -> sp/sec over the analysis window.
    meanresponses(:,:,i)=rot90(reshape( mean(mfr(i).spikes,2)*msecs_per_sec/(phase_end_msec-phase_start_msec), rows, cols ));
    imagesc(rot90(meanresponses(:,:,i)'));
    %colormap(gray);
    colorbar;
    set(gca,'XTickLabel',{});
    %title(['Mean rate (sp/sec) of cells in area ' area_name type_name ' to pattern ' num2str(i) ]);
end

figure(100+k); k= k + 1;

imagesc(rot90(mean(meanresponses,3)'));
colorbar
title('Mean firing rate for all neurons across all patterns','FontSize',14);
drawnow

figure(100+k); k= k + 1;
sortedmfr=sort(meanresponses,3,'descend');
sortedmfr=reshape(sortedmfr,rows*cols,num_patterns);
errorbar(mean(sortedmfr),std(sortedmfr));
title('Grand tuning curve over all neurons (mfr +/- std.)','FontSize',14);
ylabel('Mean firing rate (sp/sec)','FontSize',12);
xlabel('pattern number sorted from best to worst','FontSize',12);
drawnow

%{
RGM: Not needed for our analysis
figure(100+k); k= k + 1;
imagesc(sortedmfr);
%}

%{
RGM: Not needed for our analysis
figure(100+k); k= k + 1;
b=bitsinfo(sortedmfr,2.5:5:(max(sortedmfr(:,1))+2.49)); % bins of ~ 5 sp/sec.
hist(b,20);
title('histogram of bits of information conveyed by neurons','FontSize',14);
drawnow
%}

figure(100+k);k=k+1;
[best_response, best_indices]=max(meanresponses,[],3);
best_indices(best_response == 0) = 0;
imagesc(rot90(best_indices'), [0 num_patterns]); % start scale at 0 for comparison with significance plot
colorbar
axis off
title('Optimal pattern for each cell','FontSize',14);
drawnow

%{
RGM: Not needed for our analysis
figure(100+k);k=k+1;
hist(best_indices(:),num_patterns);
title('Histogram of neurons responding best to each pattern','FontSize',14);
drawnow
%}

% See if we can classify the samples given the network response vector
train_classifier = true;
if train_classifier

    samples = [];
    cats=[];
    testing_samples = [];
    testing_cats=[];

    % The first 2/3 of each pattern's trials (in presentation order) train
    % the classifier; the remaining trials test it.
    for i = 1:num_patterns

        training_patterns = floor(2/3*count(i));
        testing_patterns = count(i)-training_patterns;

        samples =[samples; mfr( i ).spikes(:,1:training_patterns)']; %#ok<AGROW>
        cats = [cats ones(1,training_patterns)*i]; %#ok<AGROW>

        testing_samples = [testing_samples; mfr(i).spikes(:,training_patterns+1:end)']; %#ok<AGROW>
        testing_cats = [testing_cats ones(1,testing_patterns)*i]; %#ok<AGROW>
    end

    if isempty(cats)
        disp('Not enough trials to train the classifier.');
    else
        % Linear layer designed with newlind, one output per category; the
        % predicted category is the index of the largest output.
        T = full(ind2vec(cats)) ; % turn category indices into one-hot columns.

        bpnet=newlind(samples',T);
        Y=sim(bpnet,samples');

        % max along dim 1 so a single-category T (1 x n) still gives one
        % prediction per sample.
        [junk, Yc] = max(Y,[],1);

        percent_correct_bp_train = sum(Yc==cats)/length(Yc)*100
        cMat_train = confusionmat(cats,Yc); % the confusion matrix

        % now on test set.
        if isempty(testing_cats)
            disp('No held-out trials to test the classifier.');
        else
            Y=sim(bpnet,testing_samples');
            [junk, Yc] = max(Y,[],1);

            percent_correct_bp_test = sum(Yc==testing_cats)/length(Yc)*100
            cMat_test = confusionmat(testing_cats,Yc); % the confusion matrix
        end
    end
end


%RGM: Continue analysis only if user sets the full_analysis argument flag
if full_analysis
    % rs(n,i,j): neuron n responds significantly differently to patterns i
    % and j (ranksum at significance_level) and more strongly to i.
    % Preallocated so pairs skipped below (too few trials) read as false
    % instead of leaving rs too small to index.
    rs = false(N, num_patterns, num_patterns);
    num_significant = zeros(num_patterns, num_patterns);

    % Show significantly preferential responses to each pattern.
    for i = 1:num_patterns
        for j = 1:num_patterns
            if ( count(j) > 1 && count(i) > 1)
                for n = 1:N
                    [junk, significant]=ranksum( mfr( i ).spikes(n, : ) , mfr( j ).spikes(n, : ) ,significance_level);
                    rs(n,i,j) = significant && mean(mfr( i ).spikes(n, : ))  > mean(mfr( j ).spikes(n, : ));
                end
                num_significant(i,j)=sum(rs(:,i,j));
                %if j>i
                %	figure(100+k);
                %	imagesc(reshape( rs(:,i,j), rows, cols ))
                %	colormap(gray);
                %	title(['Cells in area ' area_name ' that distinguish pattern ' num2str(i) ' from pattern ' num2str(j) ' statistically']);
                %	k=k+1;
                %end
            end
        end
        disp('.');
    end

    %disp('number of cells which distinguish pattern "row" from pattern "col"');
    %num_significant

    discriminators = zeros(1,num_patterns);
    merged_discriminator_map = zeros(N,1);
    sig_cells = struct('list', cell(1, num_patterns));

    for i = 1:num_patterns
        % Cells preferring pattern i significantly over every other pattern
        % (all of an empty set is true, as with the original ones(N,1) start).
        others = [1:i-1, i+1:num_patterns];
        temp = all(rs(:, i, others), 3);

        merged_discriminator_map = merged_discriminator_map + i*temp;
        discriminators(i) = sum(temp);

        figure(100+k); k= k + 1;

        imagesc(reshape( temp, rows, cols )');
        colormap(gray);
        title({['Cells in area ' area_name type_name ' that distinguish pattern ' num2str(i)],[ ' from ALL others statistically (p< ' num2str(significance_level) ', Wilcoxon RS); t=' num2str(phase_start_msec) ' ' num2str(phase_end_msec) ' and respond to maximally to it.' ]},'FontSize',14);
        % Significant cell numbers, counting from 0 as in C:
        % (map index - 1) + (1-based first neuron - 1).
        sig_cells(i).list = (find(temp)+first-2)';
    end

    %{
    RGM: Not needed for our analysis
    figure(100+k);k=k+1;
    bar(discriminators);
    title({['Histogram of cells in area ' area_name type_name ' that distinguish one pattern from ALL others'],['statistically (p< ' num2str(significance_level) ', Wilcoxon RS); t=' num2str(phase_start_msec) ' ' num2str(phase_end_msec) ' and respond to maximally to it.' ]},'FontSize',14);
    %}

    figure(100+k);k=k+1;
    imagesc(reshape( merged_discriminator_map, rows, cols )');
    title({['Cells in area ' area_name type_name ' that distinguish one pattern from ALL others'],['statistically (p< ' num2str(significance_level) ', Wilcoxon RS); t=' num2str(phase_start_msec) ' ' num2str(phase_end_msec) ' and respond to maximally to it.' ]},'FontSize',14);
    axis off
    colorbar

    %{
    merged_discriminator_map
    for i = find(merged_discriminator_map')
       disp(['Responses for unit ' num2str(i)]);
       for j = 1:num_patterns
          disp([j mfr(j).spikes(i,:)]);
       end
    end
    %}
end %if full_analysis


function S = spike_matrix(sdat, num_neurons, max_time_msec)
% SPIKE_MATRIX  Sparse (neuron x msec) spike-count matrix from spike rows.
% sdat(:,1) is the spike time in msec and sdat(:,2) the neuron id, both
% counting from 0; S(id+1, time+1) counts the spikes.  S has at least
% num_neurons rows and max_time_msec columns, so a window inside that range
% can be indexed even when it contains no spikes.
if isempty(sdat)
    % No spikes in the loaded files (load returns an empty matrix).
    S = sparse(num_neurons, max_time_msec);
    return;
end
% Rows come from the neuron-id column and columns from the time column;
% +1 converts the largest 0-based value into a required 1-based size.
n_rows = max(num_neurons, max(sdat(:,2))+1);
n_cols = max(max_time_msec, max(sdat(:,1))+1);
S = sparse(sdat(:,2)+1, sdat(:,1)+1, 1, n_rows, n_cols);

